{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fashion MNIST with `tf.keras` from Scratch\n",
    "\n",
    "This example demonstrates the workflow to create, train, and validate a \n",
    "TensorFlow `tf.keras` model, save it to HDF5 `.h5` model and convert it \n",
    "to Core ML `.mlmodel` format using the `tfcoreml` converter. For more\n",
    "examples, refer `test_tf_2x.py` file.\n",
    " \n",
    "Note: \n",
    "\n",
    "- This notebook was tested with following dependencies:\n",
    "\n",
    "```\n",
    "tensorflow==2.0.0\n",
    "coremltools==3.1\n",
    "tfcoreml==1.1\n",
    "```\n",
    "\n",
    "- Models from TensorFlow 2.0+ is supported only for `minimum_ios_deployment_target>='13'`.\n",
    "You can also use `coremltools.converters.tensorflow.convert()` \n",
    "instead of `tfcoreml.convert()` to convert your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W1101 14:00:52.328081 4735601984 __init__.py:74] TensorFlow version 2.0.0 detected. Last version known to be fully compatible is 1.14.0 .\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import tfcoreml\n",
    "\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare fashion_mnist dataset\n",
    "fashion_mnist = tf.keras.datasets.fashion_mnist\n",
    "(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()\n",
    "\n",
    "train_images = train_images / 255.0\n",
    "test_images = test_images / 255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a simple model using tf.keras\n",
    "keras_model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "    tf.keras.layers.Dense(128, activation='relu'),\n",
    "    tf.keras.layers.Dense(10, activation='softmax')\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples\n",
      "Epoch 1/10\n",
      "60000/60000 [==============================] - 3s 46us/sample - loss: 0.4976 - accuracy: 0.8258\n",
      "Epoch 2/10\n",
      "60000/60000 [==============================] - 2s 39us/sample - loss: 0.3749 - accuracy: 0.8634\n",
      "Epoch 3/10\n",
      "60000/60000 [==============================] - 2s 39us/sample - loss: 0.3377 - accuracy: 0.8774\n",
      "Epoch 4/10\n",
      "60000/60000 [==============================] - 2s 38us/sample - loss: 0.3111 - accuracy: 0.8853\n",
      "Epoch 5/10\n",
      "60000/60000 [==============================] - 2s 38us/sample - loss: 0.2921 - accuracy: 0.8909\n",
      "Epoch 6/10\n",
      "60000/60000 [==============================] - 2s 39us/sample - loss: 0.2788 - accuracy: 0.8960\n",
      "Epoch 7/10\n",
      "60000/60000 [==============================] - 2s 39us/sample - loss: 0.2669 - accuracy: 0.9008\n",
      "Epoch 8/10\n",
      "60000/60000 [==============================] - 2s 40us/sample - loss: 0.2535 - accuracy: 0.9047\n",
      "Epoch 9/10\n",
      "60000/60000 [==============================] - 2s 40us/sample - loss: 0.2442 - accuracy: 0.9080\n",
      "Epoch 10/10\n",
      "60000/60000 [==============================] - 2s 40us/sample - loss: 0.2348 - accuracy: 0.9120\n",
      "10000/1 - 0s - loss: 0.2264 - accuracy: 0.8833\n",
      "\n",
      "Test accuracy: 0.8833\n"
     ]
    }
   ],
   "source": [
    "# training and evaludate keras model\n",
    "keras_model.compile(optimizer='adam',\n",
    "                    loss='sparse_categorical_crossentropy',\n",
    "                    metrics=['accuracy'])\n",
    "\n",
    "keras_model.fit(train_images, train_labels, epochs=10)\n",
    "test_loss, test_acc = keras_model.evaluate(test_images, test_labels, verbose=2)\n",
    "\n",
    "print('\\nTest accuracy:', test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mnist_fashion_model.h5\r\n"
     ]
    }
   ],
   "source": [
    "# save the tf.keras model as .h5 model file\n",
    "model_file = './mnist_fashion_model.h5'\n",
    "keras_model.save(model_file)\n",
    "\n",
    "!ls mnist_fashion_model.h5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 assert nodes deleted\n",
      "['sequential/dense_1/BiasAdd/ReadVariableOp/resource:0', 'sequential/dense/MatMul/ReadVariableOp:0', 'sequential/dense/BiasAdd/ReadVariableOp:0', 'sequential/flatten/Reshape/shape:0', 'sequential/dense/BiasAdd/ReadVariableOp/resource:0', 'sequential/dense/MatMul/ReadVariableOp/resource:0', 'sequential/dense_1/MatMul/ReadVariableOp/resource:0', 'sequential/dense_1/BiasAdd/ReadVariableOp:0', 'sequential/dense_1/MatMul/ReadVariableOp:0']\n",
      "4 nodes deleted\n",
      "0 nodes deleted\n",
      "0 nodes deleted\n",
      "[Op Fusion] fuse_bias_add() deleted 4 nodes.\n",
      "2 identity nodes deleted\n",
      "2 disconnected nodes deleted\n",
      "[SSAConverter] Converting function main ...\n",
      "[SSAConverter] [1/7] Converting op type: 'Placeholder', name: 'flatten_input', output_shape: (1, 28, 28).\n",
      "[SSAConverter] [2/7] Converting op type: 'Const', name: 'sequential/flatten/Reshape/shape', output_shape: (2,).\n",
      "[SSAConverter] [3/7] Converting op type: 'Reshape', name: 'sequential/flatten/Reshape', output_shape: (1, 784).\n",
      "[SSAConverter] [4/7] Converting op type: 'MatMul', name: 'sequential/dense/MatMul', output_shape: (1, 128).\n",
      "[SSAConverter] [5/7] Converting op type: 'Relu', name: 'sequential/dense/Relu', output_shape: (1, 128).\n",
      "[SSAConverter] [6/7] Converting op type: 'MatMul', name: 'sequential/dense_1/MatMul', output_shape: (1, 10).\n",
      "[SSAConverter] [7/7] Converting op type: 'Softmax', name: 'Identity', output_shape: (1, 10).\n",
      "[Core ML Pass] 1 disconnected constants nodes deleted\n",
      "mnist_fashion_model.mlmodel\r\n"
     ]
    }
   ],
   "source": [
    "# get input, output node names for the TF graph from the Keras model\n",
    "input_name = keras_model.inputs[0].name.split(':')[0]\n",
    "keras_output_node_name = keras_model.outputs[0].name.split(':')[0]\n",
    "graph_output_node_name = keras_output_node_name.split('/')[-1]\n",
    "\n",
    "# convert this model to Core ML format\n",
    "model = tfcoreml.convert(tf_model_path=model_file,\n",
    "                         input_name_shape_dict={input_name: (1, 28, 28)},\n",
    "                         output_feature_names=[graph_output_node_name],\n",
    "                         minimum_ios_deployment_target='13')\n",
    "model.save('./mnist_fashion_model.mlmodel')\n",
    "\n",
    "!ls mnist_fashion_model.mlmodel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1.5719648e-09 1.7905072e-09 5.9817944e-07 8.1820750e-10 9.6943937e-09\n",
      "  5.0254831e-20 1.5249961e-07 6.2053448e-17 9.9999928e-01 1.0400648e-15]]\n",
      "[[1.57196778e-09 1.79050730e-09 5.98181146e-07 8.18209001e-10\n",
      "  9.69441238e-09 5.02548314e-20 1.52499751e-07 6.20534484e-17\n",
      "  9.99999285e-01 1.04006487e-15]]\n"
     ]
    }
   ],
   "source": [
    "# run predictions with fake image as an input\n",
    "fake_image = np.random.rand(1, 28, 28)\n",
    "\n",
    "keras_predictions = keras_model.predict(fake_image)\n",
    "print(keras_predictions[:10])\n",
    "\n",
    "coreml_predictions = model.predict({'flatten_input': fake_image})['Identity']\n",
    "print(coreml_predictions[:10])\n",
    "\n",
    "assert(np.allclose(keras_predictions, coreml_predictions))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
